/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file swi_glu_quant_static.h
 * \brief SwiGLU activation fused with static (per-tensor or per-channel) quantization.
 */
#ifndef SWI_GLU_QUANT_STATIC_H
#define SWI_GLU_QUANT_STATIC_H

#include "swi_glu_quant_base.h"

namespace SwiGluQuantOpt {
using namespace AscendC;

template <typename inType, typename outType, QuantType quantType>
class SwiGluQuantStatic : public SwiGluQuantBase {
public:
    __aicore__ inline SwiGluQuantStatic(TPipe *pipe)
    {
        pPipe = pipe;
    }

    __aicore__ inline void Init(GM_ADDR input_gm, GM_ADDR smooth_scales, GM_ADDR offsets, GM_ADDR group_index,
        GM_ADDR y_gm, GM_ADDR scale_gm, GM_ADDR workspace, const SwiGluQuantTilingData *__restrict tilingData)
    {
        ParseTilingData(tilingData);
        InitParams(sizeof(inType), sizeof(outType));
        InitBaseBuffer();
        InitAndSetBuffer(input_gm, smooth_scales, offsets, group_index, y_gm, scale_gm);
    }

    __aicore__ inline void Process()
    {
        GroupCopyIn();
        SyncAll();
        groupLocal = groupQueue.DeQue<int32_t>();
        DuplicateConst();
        ProcessCoreMultiUbMulti();

        groupQueue.FreeTensor(groupLocal);
    }

private:
    __aicore__ inline void InitAndSetBuffer(GM_ADDR input_gm, GM_ADDR smooth_scales, GM_ADDR offsets,
        GM_ADDR group_index, GM_ADDR y_gm, GM_ADDR scale_gm)
    {
        // global memory buffers
        xGm.SetGlobalBuffer((__gm__ inType *)input_gm, SPLIT_NUM * tilingData_.rowLen * tilingData_.colLen);
        yGm.SetGlobalBuffer((__gm__ int8_t *)y_gm, tilingData_.rowLen * tilingData_.colLen);
        yGmInt4.SetGlobalBuffer((__gm__ int4b_t *)y_gm, tilingData_.rowLen * tilingData_.colLen);
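        // yGm and yGmInt4 are int8 and int4 views of the same output buffer;
        // the view used at copy-out follows tilingData_.dstType (see the
        // DT_INT4 handling below).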
        group_index_Gm.SetGlobalBuffer((__gm__ int32_t *)group_index, tilingData_.groupLen);
        if constexpr (quantType == QuantType::STATIC_PER_TENSOR) {
            smooth_scales_Gm.SetGlobalBuffer((__gm__ float *)smooth_scales, tilingData_.groupLen * tilingData_.colLen);
            offsetsGm.SetGlobalBuffer((__gm__ float *)offsets, tilingData_.groupLen * tilingData_.colLen);
        } else {
            smooth_scales_Gm.SetGlobalBuffer((__gm__ float *)smooth_scales, tilingData_.groupLen);
            offsetsGm.SetGlobalBuffer((__gm__ float *)offsets, tilingData_.groupLen);
        }
        scale_Gm.SetGlobalBuffer((__gm__ float *)scale_gm, tilingData_.rowLen);

        // queue
        pPipe->InitBuffer(inQueueA, BUFFER_NUM, tileLength * sizeof(inType));
        pPipe->InitBuffer(inQueueB, BUFFER_NUM, tileLength * sizeof(inType));
        pPipe->InitBuffer(outQueueY, BUFFER_NUM, outLen * sizeof(outType));
        pPipe->InitBuffer(scaleQueue, BUFFER_NUM, basicRowLen * sizeof(float));
        pPipe->InitBuffer(groupQueue, BUFFER_NUM, alignedGroupLen * sizeof(int32_t));

        if constexpr (quantType == QuantType::STATIC_PER_TENSOR) {
            pPipe->InitBuffer(smoothQueue, BUFFER_NUM, alignedGroupLen * sizeof(float));
            pPipe->InitBuffer(offsetsQueue, BUFFER_NUM, alignedGroupLen * sizeof(float));
        } else {
            pPipe->InitBuffer(smoothQueue, BUFFER_NUM, sizeHalfLen * sizeof(float));
            pPipe->InitBuffer(offsetsQueue, BUFFER_NUM, sizeHalfLen * sizeof(float));
        }
        // temporary buffers for intermediate results
        pPipe->InitBuffer(sharedTempBuf, tileLength * sizeof(float));
        pPipe->InitBuffer(tempBufferY, tileLength * sizeof(float));
        pPipe->InitBuffer(tempYUnit, sizeHalfLen * sizeof(float));
    }

    __aicore__ inline uint32_t GetSmoothIndex(uint32_t realRowNum, int32_t &groupNum, uint32_t smoothIndex)
    {
        // find the first group whose boundary covers realRowNum and return its index
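        // Illustrative example (hypothetical values): with cumulative group
        // boundaries groupLocal = {3, 7, 10} and realRowNum = 5, the scan
        // returns index 1, since 7 >= 5.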
        for (uint32_t index = smoothIndex; index < tilingData_.groupLen; index++) {
            groupNum = groupLocal.GetValue(index);
            if (groupNum >= static_cast<int32_t>(realRowNum)) {
                return index;
            }
        }
        return SMOOTH_INDEX_UPBOUND;
    }

    __aicore__ inline void GroupCopyIn()
    {
        LocalTensor<int32_t> groupLocal = groupQueue.AllocTensor<int32_t>();

        if (tilingData_.hasGroup == 1) {
            // the copy is needed only when grouping is enabled
            uint8_t rightPadding = alignedGroupLen - tilingData_.groupLen;
            DataCopyParams copyParams{ 1, (uint16_t)(tilingData_.groupLen * sizeof(int32_t)), 0, 0 };
            DataCopyPadParams padParams{ true, 0, rightPadding, 0 };
            DataCopyPad(groupLocal, group_index_Gm, copyParams, padParams);
            if (tilingData_.groupListType == 1) {
                SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
                WaitFlag<HardEvent::MTE2_S>(EVENT_ID0);
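                // groupListType == 1 means group_index holds per-group row
                // counts; convert them in place to cumulative boundaries,
                // e.g. (hypothetical values) {3, 4, 3} -> {3, 7, 10}.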
                for (uint32_t i = 1; i < tilingData_.groupLen; i++) {
                    groupLocal.SetValue(i, groupLocal.GetValue(i - 1) + groupLocal.GetValue(i));
                }
            }
        }

        groupQueue.EnQue(groupLocal);
    }

    __aicore__ inline void SmoothCopyIn(uint32_t offset)
    {
        LocalTensor<float> smoothLocal = smoothQueue.AllocTensor<float>();
        if (smoothIsPad) {
            DataCopyParams copyParams{ 1, (uint16_t)(basicColLen * sizeof(float)), 0, 0 };
            DataCopyPadParams padParams{ false, 0, smoothRightPadding, 0 };
            DataCopyPad(smoothLocal, smooth_scales_Gm[offset], copyParams, padParams);
        } else {
            DataCopy(smoothLocal, smooth_scales_Gm[offset], basicColLen);
        }
        smoothQueue.EnQue(smoothLocal);
    }

    __aicore__ inline void OffsetsCopyIn(uint32_t offset)
    {
        LocalTensor<float> offsetsLocal = offsetsQueue.AllocTensor<float>();
        if (smoothIsPad) {
            DataCopyParams copyParams{ 1, (uint16_t)(basicColLen * sizeof(float)), 0, 0 };
            DataCopyPadParams padParams{ false, 0, smoothRightPadding, 0 };
            DataCopyPad(offsetsLocal, offsetsGm[offset], copyParams, padParams);
        } else {
            DataCopy(offsetsLocal, offsetsGm[offset], basicColLen);
        }
        offsetsQueue.EnQue(offsetsLocal);
    }

    __aicore__ inline void ProcessCoreMultiUbMulti()
    {
        uint32_t smoothIndex = 0;
        uint32_t offsetRow = 0;

        for (uint32_t ridx = 0; ridx < rowLoop; ridx++) {
            offsetRow = baseRow + ridx * basicRowLen;

            // the last row-loop iteration on each core may be partial and is sized separately
            basicRowLenCal =
                static_cast<uint16_t>((ridx == rowLoop - 1) ? (rowLenPerCore - (rowLoop - 1) * basicRowLen) :
                                                              basicRowLen);
            ProcessCoreMultiUbMultiAlign(ridx, smoothIndex, offsetRow);
        }
    }

    __aicore__ inline void ComputeVecInGmOffset(uint32_t ridx)
    {
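        // The first headCoreNum cores each own tilingData_.rowLenPerHeadCore
        // rows; rowLenPerCore is this core's own share. The GM offsets below
        // skip every row owned by the cores ahead of this one.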
        if (coreIdx < headCoreNum) {
            offsetParam.tmpVecGmOffset = static_cast<uint64_t>(coreIdx) * rowLenPerCore * mergedColLen + ridx * basicRowLen * mergedColLen;
            splitCopyoutOffset = static_cast<uint64_t>(coreIdx) * rowLenPerCore * colLen + ridx * basicRowLen * basicColLen;
        } else {
            offsetParam.tmpVecGmOffset = static_cast<uint64_t>(headCoreNum) * tilingData_.rowLenPerHeadCore * mergedColLen +
                static_cast<uint64_t>(coreIdx - headCoreNum) * rowLenPerCore * mergedColLen + ridx * basicRowLen * mergedColLen;
            splitCopyoutOffset = static_cast<uint64_t>(headCoreNum) * tilingData_.rowLenPerHeadCore * colLen +
                static_cast<uint64_t>(coreIdx - headCoreNum) * rowLenPerCore * colLen + ridx * basicRowLen * basicColLen;
        }
    }

    __aicore__ inline void ProcessCoreMultiUbMultiAlign(uint32_t ridx, uint32_t &smoothIndex, uint32_t offsetRow)
    {
        DataCopyParams splitCopyinParams;
        DataCopyParams splitCopyoutParams;

        splitCopyinParams = { basicRowLenCal, (uint16_t)(basicColLen * sizeof(inType) / blockUnit),
            (uint16_t)((mergedColLen - basicColLen) * sizeof(inType) / blockUnit), 0 };

        uint16_t dstStride = (uint16_t)((colLen - basicColLen) * sizeof(outType));
        uint16_t blockLen = (uint16_t)(basicColLen * sizeof(outType));
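        // int4 outputs pack two values per byte, so the byte-based copy
        // length and stride are halved (rounded up) relative to int8.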
        if (tilingData_.dstType == DT_INT4) {
            dstStride = CeilDiv(dstStride, (uint16_t)ONE_BYTE_INT4_NUM_TWO);
            blockLen = CeilDiv(blockLen, (uint16_t)ONE_BYTE_INT4_NUM_TWO);
        }

        splitCopyoutParams = {basicRowLenCal, blockLen, 0, dstStride};

        ComputeVecInGmOffset(ridx);

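        // Each input row packs both SwiGLU halves side by side (hence the
        // SPLIT_NUM factor on xGm); activateLeft selects whether the left or
        // the right half feeds the swish gate (operand A).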
        if (tilingData_.activateLeft == 1) {
            offsetParam.splitVecGmOffset1 = offsetParam.tmpVecGmOffset;
            offsetParam.splitVecGmOffset2 = offsetParam.splitVecGmOffset1 + tilingData_.colLen;
        } else {
            offsetParam.splitVecGmOffset2 = offsetParam.tmpVecGmOffset;
            offsetParam.splitVecGmOffset1 = offsetParam.splitVecGmOffset2 + tilingData_.colLen;
        }

        uint32_t smoothScalesOffset = smoothIndex * tilingData_.colLen;

        CopyIn(offsetParam, smoothScalesOffset, splitCopyinParams);
        Compute(offsetRow, smoothIndex);
        CopyOut(splitCopyoutOffset, splitCopyoutParams, ridx, basicRowLenCal);
    }

    __aicore__ inline void CopyIn(XxGluSingleTileOffsetParam &offsetParam, uint32_t smoothScalesOffset,
        DataCopyParams &splitCopyinParams)
    {
        LocalTensor<inType> aLocal = this->inQueueA.template AllocTensor<inType>();
        LocalTensor<inType> bLocal = this->inQueueB.template AllocTensor<inType>();

        if (isPad) {
            // Copy A
            DataCopyPadParams padParams{ false, 0, rightPadding, 0 };
            DataCopyPad(aLocal, this->xGm[offsetParam.splitVecGmOffset1], splitCopyinParams, padParams);
            // Copy B
            DataCopyPad(bLocal, this->xGm[offsetParam.splitVecGmOffset2], splitCopyinParams, padParams);
        } else {
            // Copy A
            DataCopy(aLocal, this->xGm[offsetParam.splitVecGmOffset1], splitCopyinParams);
            // Copy B
            DataCopy(bLocal, this->xGm[offsetParam.splitVecGmOffset2], splitCopyinParams);
        }

        this->inQueueA.template EnQue(aLocal);
        this->inQueueB.template EnQue(bLocal);
        // Copy Scales and Offsets
        if constexpr (quantType == QuantType::STATIC_PER_CHANNEL) {
            SmoothCopyIn(smoothScalesOffset);
            OffsetsCopyIn(smoothScalesOffset);
        }
    }

    __aicore__ inline void ComputePerChannelQuant(uint32_t offsetRow, uint32_t &smoothIndex)
    {
        uint32_t index = 0;
        uint32_t smoothOffset = 0;
        uint32_t realRowNum = 0;
        int32_t groupValue = tilingData_.rowLen;
        if (tilingData_.hasGroup == 1) {
            groupValue = groupLocal.GetValue(smoothIndex);
        }

        LocalTensor<int8_t> outLocal = outQueueY.AllocTensor<int8_t>();
        LocalTensor<float> tempFp32 = tempYUnit.Get<float>();
        LocalTensor<int32_t> tempInt32 = sharedTempBuf.Get<int32_t>(sizeHalfLen);
        auto tempHalf = tempFp32.ReinterpretCast<half>();
        LocalTensor<float> smoothLocal = smoothQueue.DeQue<float>();
        LocalTensor<float> offsetsLocal = offsetsQueue.DeQue<float>();
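        // Per-channel static quantization: each row belongs to a group, and
        // each group supplies a per-column smooth scale and offset vector;
        // every element becomes cast(x * smooth + offset).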
        for (int32_t i = 0; i < basicRowLenCal; i++) {
            index = i * sizeHalfLen;
            DataCopy(tempFp32, tmpYLocal[index], sizeHalfLen);

            realRowNum = offsetRow + i + 1;
            if (groupValue < realRowNum && smoothIndex != SMOOTH_INDEX_UPBOUND) {
                smoothIndex = GetSmoothIndex(realRowNum, groupValue, smoothIndex + 1);
                if (smoothIndex != SMOOTH_INDEX_UPBOUND) {
                    smoothQueue.FreeTensor(smoothLocal);
                    offsetsQueue.FreeTensor(offsetsLocal);
                    smoothOffset = smoothIndex * basicColLen;
                    SmoothCopyIn(smoothOffset);
                    OffsetsCopyIn(smoothOffset);
                    smoothLocal = smoothQueue.DeQue<float>();
                    offsetsLocal = offsetsQueue.DeQue<float>();
                }
            }

            if (smoothIndex != SMOOTH_INDEX_UPBOUND) {
                Mul(tempFp32, tempFp32, smoothLocal, basicColLen);
                PipeBarrier<PIPE_V>();
                Add(tempFp32, tempFp32, offsetsLocal, basicColLen);
                PipeBarrier<PIPE_V>();
            }
            CastQuantOut(tempFp32, tempInt32, tempHalf, outLocal, i);
        }
        smoothQueue.FreeTensor(smoothLocal);
        offsetsQueue.FreeTensor(offsetsLocal);
        outQueueY.template EnQue<int8_t>(outLocal);
    }

    __aicore__ inline void ComputePerTensorQuant(uint32_t offsetRow, uint32_t &smoothIndex)
    {
        uint32_t index = 0;
        uint32_t smoothOffset = 0;
        uint32_t realRowNum = 0;
        int32_t groupValue = tilingData_.rowLen;
        if (tilingData_.hasGroup == 1) {
            groupValue = groupLocal.GetValue(smoothIndex);
        }

        LocalTensor<int8_t> outLocal = outQueueY.AllocTensor<int8_t>();
        LocalTensor<float> tempFp32 = tempYUnit.Get<float>();
        LocalTensor<int32_t> tempInt32 = sharedTempBuf.Get<int32_t>(sizeHalfLen);
        LocalTensor<float> smoothLocal = smoothQueue.DeQue<float>();
        LocalTensor<float> offsetsLocal = offsetsQueue.DeQue<float>();
        auto tempHalf = tempFp32.ReinterpretCast<half>();
        float smoothVal = smoothLocal.GetValue(smoothIndex);
        float offsetsVal = offsetsLocal.GetValue(smoothIndex);
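        // Per-tensor static quantization: each group supplies one scalar
        // scale/offset pair applied uniformly across the row, so every
        // element becomes cast(x * smoothVal + offsetsVal).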
        for (int32_t i = 0; i < basicRowLenCal; i++) {
            index = i * sizeHalfLen;
            DataCopy(tempFp32, tmpYLocal[index], sizeHalfLen);
            realRowNum = offsetRow + i + 1;
            if (groupValue < realRowNum && smoothIndex != SMOOTH_INDEX_UPBOUND) {
                smoothIndex = GetSmoothIndex(realRowNum, groupValue, smoothIndex + 1);
                if (smoothIndex != SMOOTH_INDEX_UPBOUND) {
                    smoothVal = smoothLocal.GetValue(smoothIndex);
                    offsetsVal = offsetsLocal.GetValue(smoothIndex);
                }
            }

            if (smoothIndex != SMOOTH_INDEX_UPBOUND) {
                Muls(tempFp32, tempFp32, smoothVal, basicColLen);
                PipeBarrier<PIPE_V>();
                Adds(tempFp32, tempFp32, offsetsVal, basicColLen);
                PipeBarrier<PIPE_V>();
            }
            CastQuantOut(tempFp32, tempInt32, tempHalf, outLocal, i);
        }
        outQueueY.template EnQue<int8_t>(outLocal);
    }

    __aicore__ inline void Compute(uint32_t offsetRow, uint32_t &smoothIndex)
    {
        tmpYLocal = tempBufferY.Get<float>();
        LocalTensor<float> scaleLocal = scaleQueue.AllocTensor<float>();
        LocalTensor<float> tmpALocal = sharedTempBuf.Get<float>();
        LocalTensor<inType> aLocal = inQueueA.template DeQue<inType>();

        if constexpr (sizeof(inType) == sizeof(float)) {
            DataCopy(tmpALocal, aLocal, tileLength);
        } else {
            Cast(tmpALocal, aLocal, RoundMode::CAST_NONE, tileLength);
        }

        inQueueA.template FreeTensor(aLocal);
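        // The next four vector ops evaluate the swish gate
        // swish(a) = a * sigmoid(a) = a / (1 + exp(-a)).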
        Muls(tmpYLocal, tmpALocal, static_cast<float>(-1.0), tileLength);
        PipeBarrier<PIPE_V>();
        Exp(tmpYLocal, tmpYLocal, tileLength);
        PipeBarrier<PIPE_V>();
        Adds(tmpYLocal, tmpYLocal, static_cast<float>(1.0), tileLength);
        PipeBarrier<PIPE_V>();
        Div(tmpYLocal, tmpALocal, tmpYLocal, tileLength);
        PipeBarrier<PIPE_V>();

        LocalTensor<inType> bLocal = inQueueB.template DeQue<inType>();
        LocalTensor<float> tmpBLocal = sharedTempBuf.Get<float>();
        if constexpr (sizeof(inType) == sizeof(float)) {
            DataCopy(tmpBLocal, bLocal, tileLength);
        } else {
            Cast(tmpBLocal, bLocal, RoundMode::CAST_NONE, tileLength);
        }

        inQueueB.template FreeTensor(bLocal);
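        // Gate the linear half: y = swish(A) * B completes the SwiGLU
        // activation.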
        Mul(tmpYLocal, tmpYLocal, tmpBLocal, tileLength);
        PipeBarrier<PIPE_V>();

        // quant compute
        if constexpr (quantType == QuantType::STATIC_PER_CHANNEL) {
            ComputePerChannelQuant(offsetRow, smoothIndex);
        } else if constexpr (quantType == QuantType::STATIC_PER_TENSOR) {
            LocalTensor<float> smoothLocal = smoothQueue.AllocTensor<float>();
            LocalTensor<float> offsetsLocal = offsetsQueue.AllocTensor<float>();
            uint8_t rightPadding = alignedGroupLen - tilingData_.groupLen;
            DataCopyParams copyParams{ 1, (uint16_t)(tilingData_.groupLen * sizeof(float)), 0, 0 };
            DataCopyPadParams padParams{ true, 0, rightPadding, 0 };
            DataCopyPad(smoothLocal, smooth_scales_Gm, copyParams, padParams);
            smoothQueue.EnQue(smoothLocal);
            DataCopyPad(offsetsLocal, offsetsGm, copyParams, padParams);
            offsetsQueue.EnQue(offsetsLocal);
            ComputePerTensorQuant(offsetRow, smoothIndex);
            smoothQueue.FreeTensor(smoothLocal);
            offsetsQueue.FreeTensor(offsetsLocal);
        }
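        // Static quantization takes its scales as kernel inputs, so the
        // per-row scale output written back here is simply zero-filled.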
        Duplicate<float>(scaleLocal, 0.0, basicRowLenCal);
        scaleQueue.EnQue<float>(scaleLocal);
    }

private:
    GlobalTensor<inType> xGm;
};
} // namespace SwiGluQuantOpt
#endif // SWI_GLU_QUANT_STATIC_H